# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/01/07
# License: MIT License
"""
Tsinghua BCI Lab.
"""
import os
import tarfile
from typing import Union, Optional, Dict, List, cast
from pathlib import Path

import numpy as np
import py7zr
from mne import create_info
from mne.io import RawArray, Raw
from mne.channels import make_standard_montage

from .base import BaseDataset
from ..utils.download import mne_data_path
from ..utils.io import loadmat

# Download locations for the Tsinghua BCI Lab datasets.
# Lab download page: http://bci.med.tsinghua.edu.cn/download.html
_TSINGHUA_UPLOAD_URL = "http://bci.med.tsinghua.edu.cn/upload/"

# Wang2016 benchmark. Reported to answer HTTP 403, though downloads still work.
# Alternative mirrors:
#   ftp://sccn.ucsd.edu/pub/ssvep_benchmark_dataset/
#   http://www.thubci.com/uploads/down/  (may work)
Wang2016_URL = _TSINGHUA_UPLOAD_URL + "yijun/"

# BETA database. Reported to answer HTTP 403.
# Alternatives:
#   unfiltered ("wof") version:
#     https://bci.med.tsinghua.edu.cn/upload/liubingchuan_BETA_wof/
#   figshare: https://figshare.com/articles/The_BETA_database/12264401
BETA_URL = _TSINGHUA_UPLOAD_URL + "liubingchuan/"


class Wang2016(BaseDataset):
    """SSVEP benchmark dataset from Yijun Wang [1]_.

    This dataset gathered SSVEP-BCI recordings of 35 healthy subjects (17
    females, aged 17-34 years, mean age: 22 years) focusing on 40 characters
    flickering at different frequencies (8-15.8 Hz with an interval of 0.2 Hz).
    For each subject, the experiment consisted of 6 blocks. Each block
    contained 40 trials corresponding to all 40 characters indicated in a
    random order. Each trial started with a visual cue (a red square)
    indicating a target stimulus. The cue appeared for 0.5 s on the screen.
    Subjects were asked to shift their gaze to the target as soon as possible
    within the cue duration. Following the cue offset, all stimuli started to
    flicker on the screen concurrently and lasted 5 s. After stimulus offset,
    the screen was blank for 0.5 s before the next trial began, which allowed
    the subjects to have short breaks between consecutive trials. Each trial
    lasted a total of 6 s. To facilitate visual fixation, a red triangle
    appeared below the flickering target during the stimulation period.
    In each block, subjects were asked to avoid eye blinks during the
    stimulation period. To avoid visual fatigue, there was a rest for several
    minutes between two consecutive blocks. EEG data were acquired using a
    Synamps2 system (Neuroscan, Inc.) with a sampling rate of 1000 Hz.
    The amplifier frequency passband ranged from 0.15 Hz to 200 Hz. Sixty-four
    channels covered the whole scalp of the subject and were aligned according
    to the international 10-20 system. The ground was placed midway between
    Fz and FPz. The reference was located on the vertex. Electrode impedances
    were kept below 10 kΩ. To remove the common power-line noise, a notch
    filter at 50 Hz was applied in data recording. Event triggers generated by
    the computer were sent to the amplifier and recorded on an event channel
    synchronized to the EEG data.

    The continuous EEG data were segmented into 6 s epochs (500 ms
    pre-stimulus, 5.5 s post-stimulus onset). The epochs were subsequently
    downsampled to 250 Hz. Thus each trial consisted of 1500 time points.
    Finally, these data were stored as double-precision floating-point values
    in MATLAB and were named by subject index (i.e., S1.mat, …, S35.mat).
    For each file, the data loaded in MATLAB generate a 4-D matrix named
    ‘data’ with dimensions of [64, 1500, 40, 6]. The four dimensions indicate
    ‘Electrode index’, ‘Time points’, ‘Target index’, and ‘Block index’. The
    electrode positions were saved in a ‘64-channels.loc’ file. Six trials
    were available for each SSVEP frequency. Frequency and phase values for
    the 40 target indices were saved in a ‘Freq_Phase.mat’ file.

    Information for all subjects was listed in a ‘Sub_info.txt’ file. For each
    subject, there are five factors: ‘Subject Index’, ‘Gender’, ‘Age’,
    ‘Handedness’, and ‘Group’. Subjects were divided into an ‘experienced’
    group (eight subjects, S01-S08) and a ‘naive’ group (27 subjects, S09-S35)
    according to their experience in SSVEP-based BCIs.

    Frequency Table (Hz; read row by row to get target-index order)
    8    9   10   11   12   13   14   15
    8.2  9.2 10.2 11.2 12.2 13.2 14.2 15.2
    8.4  9.4 10.4 11.4 12.4 13.4 14.4 15.4
    8.6  9.6 10.6 11.6 12.6 13.6 14.6 15.6
    8.8  9.8 10.8 11.8 12.8 13.8 14.8 15.8

    Notes
    -----
    1. Subject 5 (S5) is not available from the download URL.
    2. Each subject is loaded as one session with one run per block. A run
       concatenates the block's 40 trials in target-index order and carries
       the trigger code of each trial (1-40) in an ``STI 014`` stim channel,
       0.5 s after the trial start.

    References
    ----------
    .. [1] Wang Y, Chen X, Gao X, et al. A Benchmark Dataset for SSVEP-Based
    Brain-Computer Interfaces[J]. IEEE Transactions on Neural Systems and
    Rehabilitation Engineering, 2017, 25(10): 1746-1752.
    """

    # The 60 electrodes exposed as the dataset's channels. The .mat files hold
    # 64 electrode rows; M1, M2, CB1 and CB2 are re-inserted when building the
    # Info so the row order matches the files.
    _CHANNELS = [
        "FP1", "FPZ", "FP2", "AF3", "AF4", "F7", "F5", "F3", "F1", "FZ", "F2",
        "F4", "F6", "F8", "FT7", "FC5", "FC3", "FC1", "FCZ", "FC2", "FC4",
        "FC6", "FT8", "T7", "C5", "C3", "C1", "CZ", "C2", "C4", "C6", "T8",
        "TP7", "CP5", "CP3", "CP1", "CPZ", "CP2", "CP4", "CP6", "TP8", "P7",
        "P5", "P3", "P1", "PZ", "P2", "P4", "P6", "P8", "PO7", "PO5", "PO3",
        "POZ", "PO4", "PO6", "PO8", "O1", "OZ", "O2",
    ]

    # Stimulus frequencies (Hz) in target-index order: target k -> _FREQS[k-1].
    _FREQS = [
        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2,
        8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4,
        8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6,
        8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8,
    ]

    # Initial phases in the same target order. The values appear to be
    # multiples of pi (0.5 -> pi/2); TODO confirm against Freq_Phase.mat.
    _PHASES = [
        0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5,
        0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5, 0.0,
        1.0, 1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5,
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
        0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5,
    ]

    # Event name (frequency as str) -> (trigger code, window). The trigger
    # codes 1..40 are the values written to the stim channel. (0, 5) is
    # presumably the epoch window in seconds, matching the 5 s flicker.
    _EVENTS = {str(freq): (i + 1, (0, 5)) for i, freq in enumerate(_FREQS)}

    def __init__(self):
        super().__init__(
            dataset_code="wang2016",
            subjects=list(range(1, 36)),
            events=self._EVENTS,
            channels=self._CHANNELS,
            srate=250,
            paradigm="ssvep",
        )

    def data_path(
        self,
        subject: Union[str, int],
        path: Optional[Union[str, Path]] = None,
        force_update: bool = False,
        update_path: Optional[bool] = None,
        proxies: Optional[Dict[str, str]] = None,
        verbose: Optional[Union[bool, str, int]] = None,
    ) -> List[List[Union[str, Path]]]:
        """Download (if needed) and unpack the .mat file of one subject.

        Parameters
        ----------
        subject : int
            Subject id; must be one of ``self.subjects``.
        path, force_update, update_path, proxies
            Forwarded to ``mne_data_path``.
        verbose
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        list of list of str
            ``[[mat_file]]``: one session containing one file.

        Raises
        ------
        ValueError
            If ``subject`` is not in ``self.subjects``.
        """
        if subject not in self.subjects:
            raise ValueError("Invalid subject id")

        subject = cast(int, subject)
        url = "{:s}S{:d}.mat.7z".format(Wang2016_URL, subject)
        file_dest = mne_data_path(
            url,
            "tsinghua",
            path=path,
            proxies=proxies,
            force_update=force_update,
            update_path=update_path,
        )

        # "S<k>.mat.7z" -> "S<k>.mat". The archive is expected to unpack that
        # file next to itself, so extract only on first use.
        subject_file = os.path.splitext(file_dest)[0]
        if not os.path.exists(subject_file):
            with py7zr.SevenZipFile(file_dest, "r") as archive:
                archive.extractall(path=Path(file_dest).parent)
        return [[subject_file]]

    def _make_info_and_montage(
        self, verbose: Optional[Union[bool, str, int]] = None
    ):
        """Return ``(info, montage)`` for the 64 electrodes plus a stim channel."""
        montage = make_standard_montage("standard_1005")
        # Upper-case the montage names so they match the channel names below.
        montage.rename_channels(
            {ch_name: ch_name.upper() for ch_name in montage.ch_names}
        )
        # Rebuild the 64-row electrode order of the .mat files by re-inserting
        # the electrodes that are not part of _CHANNELS.
        ch_names = [ch_name.upper() for ch_name in self._CHANNELS]
        ch_names.insert(32, "M1")
        ch_names.insert(42, "M2")
        ch_names.insert(59, "CB1")
        ch_names = ch_names + ["CB2", "STI 014"]
        # CB1/CB2 are typed "misc" rather than "eeg"; the last row is the
        # trigger channel.
        ch_types = [
            "misc" if ch_name in ("CB1", "CB2") else "eeg" for ch_name in ch_names
        ]
        ch_types[-1] = "stim"

        info = create_info(
            ch_names=ch_names, ch_types=ch_types, sfreq=self.srate, verbose=verbose
        )
        return info, montage

    def _get_single_subject_data(
        self, subject: Union[str, int], verbose: Optional[Union[bool, str, int]] = None
    ) -> Dict[str, Dict[str, Raw]]:
        """Load one subject as ``{"session_0": {"run_<block>": Raw}}``.

        Each block's trials are concatenated in target-index order into one
        continuous ``RawArray``. The trigger code of each trial is written to
        the ``STI 014`` channel 0.5 s after the trial start.
        """
        dests = self.data_path(subject, verbose=verbose)
        raw_mat = loadmat(dests[0][0])
        # (electrode, time, target, block); scaled by 1e-6, presumably uV -> V.
        epoch_data = raw_mat["data"] * 1e-6

        # Stimulus onset follows 0.5 s of pre-stimulus data (125 samples at 250 Hz).
        onset = int(round(0.5 * self.srate))
        n_targets = len(self._FREQS)
        stim = np.zeros((1, *epoch_data.shape[1:]))
        # stim[0, onset] has shape (target, block); every block gets codes 1..40.
        stim[0, onset] = np.tile(
            np.arange(1, n_targets + 1)[:, np.newaxis], (1, epoch_data.shape[-1])
        )
        epoch_data = np.concatenate((epoch_data, stim), axis=0)
        # -> (channel, block, target, time)
        data = np.transpose(epoch_data, (0, 3, 2, 1))

        info, montage = self._make_info_and_montage(verbose)

        runs = {}
        for i in range(data.shape[1]):
            # Concatenate the block's trials along time: (channel, target * time).
            raw_data = np.reshape(data[:, i, ...], (data.shape[0], -1))
            raw = RawArray(data=raw_data, info=info, verbose=verbose)
            raw.set_montage(montage, verbose=verbose)
            runs["run_{:d}".format(i)] = raw

        return {"session_0": runs}

    def get_freq(self, event: str) -> float:
        """Return the stimulus frequency (Hz) of ``event``, a key of ``_EVENTS``.

        Raises ``KeyError`` for an unknown event name.
        """
        return self._FREQS[self._EVENTS[event][0] - 1]

    def get_phase(self, event: str) -> float:
        """Return the initial phase of ``event`` as stored in ``_PHASES``.

        Raises ``KeyError`` for an unknown event name.
        """
        return self._PHASES[self._EVENTS[event][0] - 1]


class BETA(BaseDataset):
    """BETA SSVEP dataset [1]_.

    EEG data after preprocessing are stored as a 4-way tensor, with
    dimensions of channel x time point x block x condition. Each trial
    comprises 0.5 s of data before the event onset and 0.5 s of data after
    the stimulation window of 2 s or 3 s. For S1-S15, the stimulation window
    is 2 s and the trial length is 3 s, whereas for S16-S70 the stimulation
    window is 3 s and the trial length is 4 s. Additional details about the
    channel and condition information can be found in the supplementary
    information described below.

    The supplementary information (eight items) covers personal information,
    channel information, the frequency and initial phase associated with each
    condition, SNR and sampling rate. The personal information contains the
    age and gender of the subject. For the channel information, a location
    matrix (64 x 4) is provided: the first column indicates the channel index,
    the second and third columns indicate the degree and radius in polar
    coordinates, and the last column indicates the channel name. The SNR
    information contains the mean narrow-band SNR and wide-band SNR matrices
    for each subject, calculated with Eqs. (3) and (4) of [1], respectively.
    The initial phase is given in radians.

    Preprocessing: 3-100 Hz band-pass filtering (eegfilt), followed by
    downsampling to 250 Hz.

    Notes
    -----
    Each subject is loaded as one session with one run per block. A run
    concatenates the block's 40 trials in condition order and carries the
    trigger code of each trial (1-40) in an ``STI 014`` stim channel, 0.5 s
    after the trial start.

    References
    ----------
    .. [1] Liu B, Huang X, Wang Y, et al. BETA: A Large Benchmark Database
    Toward SSVEP-BCI Application[J]. Frontiers in Neuroscience, 2020, 14: 627.
    """

    # The 60 electrodes exposed as the dataset's channels. M1, M2, CB1 and CB2
    # are re-inserted when building the Info so the row order matches the
    # 64-row .mat data.
    _CHANNELS = [
        "FP1", "FPZ", "FP2", "AF3", "AF4", "F7", "F5", "F3", "F1", "FZ", "F2",
        "F4", "F6", "F8", "FT7", "FC5", "FC3", "FC1", "FCZ", "FC2", "FC4",
        "FC6", "FT8", "T7", "C5", "C3", "C1", "CZ", "C2", "C4", "C6", "T8",
        "TP7", "CP5", "CP3", "CP1", "CPZ", "CP2", "CP4", "CP6", "TP8", "P7",
        "P5", "P3", "P1", "PZ", "P2", "P4", "P6", "P8", "PO7", "PO5", "PO3",
        "POZ", "PO4", "PO6", "PO8", "O1", "OZ", "O2",
    ]

    # Stimulus frequencies (Hz) in condition order: condition k -> _FREQS[k-1].
    _FREQS = [
        8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0,
        10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4, 11.6,
        11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13.0, 13.2,
        13.4, 13.6, 13.8, 14.0, 14.2, 14.4, 14.6, 14.8,
        15.0, 15.2, 15.4, 15.6, 15.8, 8.0, 8.2, 8.4,
    ]
    # Initial phases in the same condition order. The values appear to be
    # multiples of pi (0.5 -> pi/2); TODO confirm against the .mat suppl_info.
    _PHASES = [
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
        1.5, 0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0,
    ]

    # Event name (frequency as str) -> (trigger code, window). The trigger
    # codes 1..40 are the values written to the stim channel. (0, 2) is the
    # 2 s stimulation window shared by all subjects; S16-S70 were actually
    # stimulated for 3 s.
    _EVENTS = {str(freq): (i + 1, (0, 2)) for i, freq in enumerate(_FREQS)}

    def __init__(self):
        super().__init__(
            dataset_code="beta",
            subjects=list(range(1, 71)),
            events=self._EVENTS,
            channels=self._CHANNELS,
            srate=250,
            paradigm="ssvep",
        )

    def data_path(
        self,
        subject: Union[str, int],
        path: Optional[Union[str, Path]] = None,
        force_update: bool = False,
        update_path: Optional[bool] = None,
        proxies: Optional[Dict[str, str]] = None,
        verbose: Optional[Union[bool, str, int]] = None,
    ) -> List[List[Union[str, Path]]]:
        """Download (if needed) and unpack the archive holding one subject.

        Parameters
        ----------
        subject : int
            Subject id; must be one of ``self.subjects``.
        path, force_update, update_path, proxies
            Forwarded to ``mne_data_path``.
        verbose
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        list of list of str
            ``[[mat_file]]``: one session containing one file.

        Raises
        ------
        ValueError
            If ``subject`` is not in ``self.subjects``.
        tarfile.TarError
            If the archive contains a member that would be written outside
            its destination directory.
        """
        if subject not in self.subjects:
            raise ValueError("Invalid subject id")

        subject = cast(int, subject)
        # Subjects are shipped ten per archive: S1-S10, S11-S20, ..., S61-S70.
        first = (subject - 1) // 10 * 10 + 1
        archive_name = "S{:d}-S{:d}.tar.gz".format(first, first + 9)

        url = "{:s}{:s}".format(BETA_URL, archive_name)

        file_dest = mne_data_path(
            url,
            "tsinghua",
            path=path,
            proxies=proxies,
            force_update=force_update,
            update_path=update_path,
        )

        parent_dir = Path(file_dest).parent
        subject_file = os.path.join(parent_dir, "S{:d}.mat".format(subject))

        # One archive serves ten subjects; extract only if this subject's
        # file is not there yet.
        if not os.path.exists(subject_file):
            with tarfile.open(file_dest, "r:gz") as archive:
                self._safe_extractall(archive, parent_dir)
        dests: List[List[Union[str, Path]]] = [[subject_file]]
        return dests

    @staticmethod
    def _safe_extractall(archive: tarfile.TarFile, dest: Path) -> None:
        """Extract ``archive`` into ``dest``, refusing members that escape it.

        The archives are downloaded over plain HTTP, so member paths are not
        trusted.
        """
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ and the security backports provide the "data"
            # extraction filter, which blocks path traversal outside ``dest``
            # as well as unsafe links and special files.
            archive.extractall(path=dest, filter="data")
            return
        # Fallback for older interpreters: check member names only.
        root = Path(dest).resolve()
        for member in archive.getmembers():
            target = (root / member.name).resolve()
            if target != root and root not in target.parents:
                raise tarfile.TarError(
                    "Refusing to extract {!r} outside {}".format(member.name, root)
                )
        archive.extractall(path=dest)

    def _make_info_and_montage(
        self, verbose: Optional[Union[bool, str, int]] = None
    ):
        """Return ``(info, montage)`` for the 64 electrodes plus a stim channel."""
        montage = make_standard_montage("standard_1005")
        # Upper-case the montage names so they match the channel names below.
        montage.rename_channels(
            {ch_name: ch_name.upper() for ch_name in montage.ch_names}
        )
        # Rebuild the 64-row electrode order of the .mat files by re-inserting
        # the electrodes that are not part of _CHANNELS.
        ch_names = [ch_name.upper() for ch_name in self._CHANNELS]
        ch_names.insert(32, "M1")
        ch_names.insert(42, "M2")
        ch_names.insert(59, "CB1")
        ch_names = ch_names + ["CB2", "STI 014"]
        # CB1/CB2 are typed "misc" rather than "eeg"; the last row is the
        # trigger channel.
        ch_types = [
            "misc" if ch_name in ("CB1", "CB2") else "eeg" for ch_name in ch_names
        ]
        ch_types[-1] = "stim"

        info = create_info(
            ch_names=ch_names, ch_types=ch_types, sfreq=self.srate, verbose=verbose
        )
        return info, montage

    def _get_single_subject_data(
        self, subject: Union[str, int], verbose: Optional[Union[bool, str, int]] = None
    ) -> Dict[str, Dict[str, Raw]]:
        """Load one subject as ``{"session_0": {"run_<block>": Raw}}``.

        Each block's trials are concatenated in condition order into one
        continuous ``RawArray``. The trigger code of each trial is written to
        the ``STI 014`` channel 0.5 s after the trial start.
        """
        dests = self.data_path(subject, verbose=verbose)
        raw_mat = loadmat(dests[0][0])
        # (channel, time, block, condition); scaled by 1e-6, presumably uV -> V.
        epoch_data = raw_mat["data"]["EEG"] * 1e-6

        # Stimulus onset follows 0.5 s of pre-stimulus data (125 samples at 250 Hz).
        onset = int(round(0.5 * self.srate))
        n_conditions = len(self._FREQS)
        stim = np.zeros((1, *epoch_data.shape[1:]))
        # stim[0, onset] has shape (block, condition); every block gets 1..40.
        stim[0, onset] = np.tile(
            np.arange(1, n_conditions + 1), (epoch_data.shape[-2], 1)
        )
        epoch_data = np.concatenate((epoch_data, stim), axis=0)
        # -> (channel, condition, block, time)
        data = np.transpose(epoch_data, (0, 3, 2, 1))

        info, montage = self._make_info_and_montage(verbose)

        runs = {}
        for i in range(data.shape[-2]):
            # Concatenate the block's trials along time: (channel, condition * time).
            raw_data = np.reshape(data[..., i, :], (data.shape[0], -1))
            raw = RawArray(data=raw_data, info=info, verbose=verbose)
            raw.set_montage(montage, verbose=verbose)
            runs["run_{:d}".format(i)] = raw

        return {"session_0": runs}

    def get_freq(self, event: str) -> float:
        """Return the stimulus frequency (Hz) of ``event``, a key of ``_EVENTS``.

        Raises ``KeyError`` for an unknown event name.
        """
        return self._FREQS[self._EVENTS[event][0] - 1]

    def get_phase(self, event: str) -> float:
        """Return the initial phase of ``event`` as stored in ``_PHASES``.

        Raises ``KeyError`` for an unknown event name.
        """
        return self._PHASES[self._EVENTS[event][0] - 1]
